import time
from collections import Counter
import matplotlib

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib.colors import ListedColormap
from sklearn.manifold import TSNE
from mpl_toolkits.mplot3d import Axes3D
from torch.utils.data import WeightedRandomSampler
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm

from continual_rl.policies.arc.arc_buffer import SelfSupervisedBuffer
from continual_rl.policies.arc.arc_nets import ActionEncoder, ActionDecoder, ARCImpalaNet
from continual_rl.policies.impala.torchbeast.monobeast import Monobeast
from continual_rl.utils import utils
from continual_rl.utils.arc_visual import action_names, cmap
import logging
import warnings

# Turn off the matplotlib font-manager logger.
logging.getLogger('matplotlib.font_manager').disabled = True
# Ignore all UserWarning and DeprecationWarning messages. NOTE: these filters are
# process-wide, so importing this module also hides such warnings from other code.
warnings.simplefilter("ignore", category=UserWarning)
warnings.simplefilter("ignore", category=DeprecationWarning)


class ARCMonobeast(Monobeast):
    def __init__(self, model_flags, observation_spaces, action_spaces, policy_class):
        """Build the ARC variant of Monobeast.

        ``policy_class`` is accepted for interface compatibility but ignored: the parent is always
        constructed with ``ARCImpalaNet``, a network whose output is an action representation rather
        than an action. On top of the parent this sets up the self-supervised buffer, the action
        encoder, and (via ``setup_decoder``) the action decoder and its optimizer.
        """
        super().__init__(model_flags, observation_spaces, action_spaces, ARCImpalaNet)

        assert not model_flags.use_lstm, "AR does not presently support using LSTMs."

        # Index of the task currently running; the first task is 0.
        self._task_run_id = 0
        # (state, action, next_state) transitions gathered during exploration.
        self._sl_buffer = SelfSupervisedBuffer(model_flags.sl_buffer_items)
        # When False, on_act_unroll_complete does not copy unrolls into the buffer.
        self.sl_buffer_toggle = False

        # Data kept from earlier tasks for the EWC penalties; filled in later.
        self.decoder_dataset = None
        self.encoder_dataset = None

        # Self-supervised action model: the encoder is built here, the decoder and the
        # optimizer over both are created by setup_decoder().
        encoder = ActionEncoder(self.max_observation_space.shape, model_flags.embedding_size)
        self.action_encoder = encoder.to(model_flags.device)
        self.action_decoder = None
        self.sl_optimizer = None
        self.setup_decoder()

        # Actors start out acting randomly (exploration).
        self.actor_is_random = True
        # Per-task (embeddings, action labels) pairs saved for visualization.
        self.representation_buffers = []

    def setup_decoder(self):
        """Create the action decoder, or widen it, for the current task's action space.

        On the first call a new ``ActionDecoder`` with one output per action is built. On later
        calls the existing decoder's output is expanded to the new action count (``decoder_dataset``
        is passed along to ``expand_output``). In both cases the self-supervised Adam optimizer is
        rebuilt over the encoder and decoder parameters, and the decoder is installed in both the
        actor and learner models.
        """
        num_actions = self.action_spaces[self._task_run_id].n
        device = self._model_flags.device

        if self.action_decoder is not None:
            # The action space grew: widen the decoder's output layer.
            self.logger.info("扩展解码器的输出维度：" + str(num_actions))
            self.action_decoder.expand_output(num_actions, self.decoder_dataset)
            self.action_decoder.to(device)
        else:
            self.logger.info("设置新的解码器，当前任务的动作空间维度：" + str(num_actions))
            self.action_decoder = ActionDecoder(self._model_flags.embedding_size, num_actions).to(device)

        # Recreate the optimizer so it tracks the (possibly new) decoder parameters.
        sl_params = [*self.action_encoder.parameters(), *self.action_decoder.parameters()]
        self.sl_optimizer = torch.optim.Adam(sl_params, lr=self._model_flags.learning_rate)

        # Freezing the encoder for later tasks and adding the decoder to the policy optimizer
        # were tried and are currently disabled.

        for policy_model in (self.actor_model, self.learner_model):
            policy_model.set_action_decoder(self.action_decoder)

        self.learner_model.to(device)

    def create_buffer_specs(self, unroll_length, obs_shape, num_actions):
        """Return the rollout-buffer layout: for each key, its ``size`` and ``dtype``.

        Every entry holds ``unroll_length + 1`` time steps. Here ``num_actions`` is the width of the
        network's action output (the action embedding), not the size of the environment's action
        space, so ``action`` and ``last_action`` are float32 vectors. ``policy_logits`` is
        deliberately not stored (the decoder output is not kept in the buffer).
        """
        steps = unroll_length + 1

        def spec(dtype, *trailing_dims):
            return dict(size=(steps, *trailing_dims), dtype=dtype)

        return {
            "frame": spec(torch.uint8, *obs_shape),
            "reward": spec(torch.float32),
            "done": spec(torch.bool),
            "episode_return": spec(torch.float32),
            "episode_step": spec(torch.int32),
            "baseline": spec(torch.float32),
            "uncertainty": spec(torch.float32),
            # Embedding of the previous step's action.
            "last_action": spec(torch.float32, num_actions),
            # Embedding of the chosen action.
            "action": spec(torch.float32, num_actions),
            # Decoded action index that was actually executed.
            "actual_action": spec(torch.int64),
        }

    def get_policy_logits(self, batch_output):
        """Decode the action embeddings in ``batch_output["action"]`` into policy logits.

        The decoder returns a pair; only its first element is used.
        """
        action_embedding = batch_output["action"]
        logits, _discarded = self.action_decoder(action_embedding)
        return logits

    @staticmethod
    def preprocess_env_output(env_output, model_flags, agent_output=None):
        """Set ``env_output["last_action"]`` to the embedding the model should see as its previous action.

        - ``agent_output is None``: a zero tensor of shape (1, ``model_flags.embedding_size``).
        - ``agent_output`` has an ``"action"`` entry: that action embedding (the environment itself
          only reports the decoded action).
        - otherwise: ``env_output`` is returned unchanged, since nothing needs to go to the model.

        ``env_output`` is modified in place and also returned.
        """
        if agent_output is None:
            previous_embedding = torch.zeros(1, model_flags.embedding_size)
        elif "action" in agent_output:
            previous_embedding = agent_output["action"]
        else:
            return env_output
        env_output["last_action"] = previous_embedding
        return env_output

    def on_act_unroll_complete(self, task_flags, actor_index, agent_output, env_output, new_buffers):
        """Store a finished unroll in the self-supervised buffer as (s_t, a_t, s_{t+1}) transitions.

        Does nothing unless ``sl_buffer_toggle`` is set. States are frames ``[:-1]``, next states
        are frames ``[1:]``, and actions are the decoded ``actual_action`` entries ``[:-1]``.
        """
        if self.sl_buffer_toggle:
            frames = new_buffers["frame"]
            self._sl_buffer.add(frames[:-1], new_buffers["actual_action"][:-1], frames[1:])

        # def learn(self, model_flags, task_flags, actor_model, learner_model, batch, initial_agent_state, optimizer,

    #           scheduler, lock, ):
    #     stat = super().learn(model_flags, task_flags, actor_model, learner_model, batch, initial_agent_state, optimizer,
    #                          scheduler, lock)
    #     # 在每次策略优化后进行监督学习训练
    #     # 缓冲区满的情况下，每次策略更新一步，自监督学习更新多步
    #     stat = self._sl_learn(model_flags, update_steps=model_flags.sl_update_steps, stat=stat)
    #     return stat

    def build_decoder_dataset(self, data_loader, device='cuda'):
        """Add (embedding, action) pairs from ``data_loader`` to ``self.decoder_dataset``.

        Each batch ``(states, actions, next_states)`` is embedded with the current action encoder.
        The new pairs are concatenated in front of any pairs already stored, and the result is kept
        as a ``TensorDataset`` (it is later passed to ``compute_fisher_information`` and
        ``expand_output`` of the decoder).

        Embeddings are computed under ``torch.no_grad()``: they are stored as fixed data, and
        keeping the encoder's autograd graph attached to every stored tensor would only hold
        memory (and could leak gradients into the encoder if something backpropagates through them).

        Args:
            data_loader: iterable of ``(states, actions, next_states)`` batches.
            device: device the encoder inputs and stored tensors are moved to.
        """
        self.logger.info("构建解码器数据集")
        new_embeddings = []
        new_actions = []
        with torch.no_grad():
            for states, actions, next_states in data_loader:
                new_embeddings.append(self.action_encoder(states.to(device), next_states.to(device)))
                new_actions.append(actions.to(device))

        if self.decoder_dataset is not None:
            # Keep what earlier tasks contributed.
            old_embeddings, old_actions = self.decoder_dataset.tensors
            new_embeddings.append(old_embeddings)
            new_actions.append(old_actions)

        if not new_embeddings:
            # Empty loader and nothing stored yet: torch.cat would fail on an empty list.
            return

        self.decoder_dataset = torch.utils.data.TensorDataset(torch.cat(new_embeddings, dim=0),
                                                              torch.cat(new_actions, dim=0))
        self.logger.info("解码器数据集大小：" + str(len(self.decoder_dataset)))

    def build_encoder_dataset(self, data_loader, device='cuda'):
        """Add (state, next_state, embedding) triples from ``data_loader`` to ``self.encoder_dataset``.

        The embedding is the current action encoder's output for ``(state, next_state)``. New
        triples are concatenated in front of any already stored, and the result is kept as a
        ``TensorDataset``.

        Embeddings are computed under ``torch.no_grad()``: they are stored as fixed data, so they
        must not keep the encoder's autograd graph alive. The batch actions are not needed and are
        not moved to ``device``.

        Args:
            data_loader: iterable of ``(states, actions, next_states)`` batches.
            device: device the inputs and stored tensors are moved to.
        """
        all_states = []
        all_next_states = []
        all_embeddings = []
        with torch.no_grad():
            for states, _actions, next_states in data_loader:
                states, next_states = states.to(device), next_states.to(device)
                all_states.append(states)
                all_next_states.append(next_states)
                all_embeddings.append(self.action_encoder(states, next_states))

        if self.encoder_dataset is not None:
            # Keep what earlier tasks contributed.
            old_states, old_next_states, old_embeddings = self.encoder_dataset.tensors
            all_states.append(old_states)
            all_next_states.append(old_next_states)
            all_embeddings.append(old_embeddings)

        if not all_states:
            # Empty loader and nothing stored yet: torch.cat would fail on an empty list.
            return

        self.encoder_dataset = torch.utils.data.TensorDataset(torch.cat(all_states, dim=0),
                                                              torch.cat(all_next_states, dim=0),
                                                              torch.cat(all_embeddings, dim=0))
        self.logger.info("编码器数据集大小：" + str(len(self.encoder_dataset)))

    def _sl_learn(self, model_flags, epoch, stat=None):
        """Train the action encoder/decoder on the self-supervised buffer.

        Nothing happens until the buffer holds at least ``model_flags.sl_buffer_items``
        transitions. Batches are drawn by a ``WeightedRandomSampler`` whose per-sample weight is
        ``total / count(action)``, so each action is sampled about equally often.

        Per batch the loss is ``_sl_loss`` plus, when the corresponding flag is > 0:
        ``contrastive_alpha * _contrastive_loss``, decoder EWC (``lambda_ewc``), encoder EWC
        (``encoder_lambda_ewc``), and a rank term (``rank_weight``) for actions removed from the
        current task's action space. Every term is logged to ``utils.summary_writer``.

        Args:
            model_flags: run configuration.
            epoch: number of passes over the sampled buffer (cast with ``int``).
            stat: optional dict receiving ``"sl_loss"``; a fresh dict is used when omitted
                (a ``{}`` default would be one dict shared by every call).

        Returns:
            ``stat``. ``stat["sl_loss"]`` is the mean loss of the last epoch (NaN if no epoch ran).
        """
        if stat is None:
            stat = {}

        # Skip training until the buffer is full.
        if self._sl_buffer.size() < model_flags.sl_buffer_items:
            return stat

        self.logger.info("自监督学习训练")

        dataset = self._sl_buffer.get_torch_dataset()
        buffer_actions = dataset.tensors[1].numpy()
        total_count = len(buffer_actions)
        if total_count == 0:
            # Nothing to train on (WeightedRandomSampler rejects num_samples=0).
            return stat

        # Log how many samples of each action the buffer holds.
        action_counts = Counter(buffer_actions)
        self.logger.info("训练缓冲区中不同动作的数量：")
        for action, count in action_counts.items():
            self.logger.info(f"动作 {action} 数量：{count}")

        # Inverse-frequency weights balance the actions when sampling with replacement.
        weights = np.zeros(total_count)
        for action, count in action_counts.items():
            weights[buffer_actions == action] = total_count / count
        sampler = WeightedRandomSampler(weights, num_samples=total_count, replacement=True)
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=model_flags.sl_batch_size, sampler=sampler)
        num_batches = len(data_loader)
        task_id = self._task_run_id

        epoch_average_loss = float("nan")  # reported if no epoch runs (epoch == 0)
        pbar = tqdm(range(int(epoch)))
        for e in pbar:
            epoch_loss = 0.0
            for i, (states, actions, next_states) in enumerate(data_loader):
                step = i + e * num_batches
                states = states.to(model_flags.device)
                actions = actions.to(model_flags.device)
                next_states = next_states.to(model_flags.device)

                embedding = self.action_encoder(states, next_states)
                # Only the current task's action space is used as the prediction target: actions
                # explored in this phase never fall outside it. all_action_logit presumably covers
                # every decoder output (it is sliced beyond the current action count below).
                action_logit, all_action_logit = self.action_decoder(embedding, self.action_spaces[task_id])

                # All terms are combined out-of-place: an in-place `loss += ...` would mutate
                # sl_loss itself when it is the same tensor, corrupting the logged SL loss.
                sl_loss = self._sl_loss(action_logit, actions)
                loss = sl_loss

                if model_flags.contrastive_alpha > 0:
                    contrastive_loss = self._contrastive_loss(embedding, actions)
                    loss = loss + model_flags.contrastive_alpha * contrastive_loss
                    utils.summary_writer.add_scalar(f"Representation Learning/Contrastive Loss/{task_id}",
                                                    contrastive_loss.item(), step)

                # EWC on the decoder once Fisher information exists (i.e. not on the first task).
                if model_flags.lambda_ewc > 0 and self.action_decoder.fisher_information is not None:
                    ewc_loss = self.action_decoder.ewc_loss()
                    utils.summary_writer.add_scalar(f"Representation Learning/EWC Loss/{task_id}",
                                                    ewc_loss.item() * model_flags.lambda_ewc, step)
                    loss = loss + ewc_loss * model_flags.lambda_ewc

                # EWC on the encoder once its Fisher information exists.
                if model_flags.encoder_lambda_ewc > 0 and self.action_encoder.fisher_information is not None:
                    ewc_loss = self.action_encoder.ewc_loss()
                    utils.summary_writer.add_scalar(f"Representation Learning/Encoder EWC Loss/{task_id}",
                                                    ewc_loss.item() * model_flags.encoder_lambda_ewc, step)
                    loss = loss + ewc_loss * model_flags.encoder_lambda_ewc

                # The action space shrank: the decoder columns past the current action count
                # belong to removed actions. Penalize samples where the best current action
                # outscores the best removed action by more than 0.01, so the decoder keeps its
                # knowledge of the removed actions.
                if model_flags.rank_weight > 0 and action_logit.shape[-1] < all_action_logit.shape[-1]:
                    new_num_actions = action_logit.shape[-1]
                    max_prob_reduced = all_action_logit[:, new_num_actions:].max(dim=-1).values
                    max_prob = action_logit.max(dim=-1).values
                    rank_loss = F.relu(max_prob - max_prob_reduced - 0.01).mean()
                    utils.summary_writer.add_scalar(f"Representation Learning/Rank Loss/{task_id}",
                                                    rank_loss.item() * model_flags.rank_weight, step)
                    loss = loss + rank_loss * model_flags.rank_weight

                epoch_loss += loss.item()

                self.sl_optimizer.zero_grad()
                loss.backward()
                self.sl_optimizer.step()

                utils.summary_writer.add_scalar(f"Representation Learning/Supervised Learning Loss/{task_id}",
                                                sl_loss.item(), step)

            epoch_average_loss = epoch_loss / num_batches
            pbar.set_description("SL Loss: %.4f" % epoch_average_loss)  # mean loss of this epoch

        if model_flags.lambda_ewc > 0:
            # Keep this task's data for later Fisher-information estimates.
            self.build_decoder_dataset(data_loader, device=model_flags.device)

        stat["sl_loss"] = epoch_average_loss
        self.logger.info("自监督学习损失：%.4f", stat["sl_loss"])

        # Push the updated decoder into the policy networks.
        self.actor_model.update_action_decoder(self.action_decoder)
        self.learner_model.update_action_decoder(self.action_decoder)
        return stat

    def _sl_loss(self, action_logit, batch_action):
        """Action-prediction loss: KL(one-hot target || softmax(action_logit)), batch-averaged.

        ``F.kl_div`` expects its first argument as log-probabilities, so the logits go through
        ``log_softmax``. Passing plain probabilities would reduce the loss to
        ``-mean(p_true)``, which has vanishing gradients for badly-predicted samples. With a
        one-hot target this KL equals the cross-entropy of the true action.

        Args:
            action_logit: decoder logits; the last dim indexes actions.
            batch_action: integer action indices, one per row of ``action_logit``.

        Returns:
            Scalar loss tensor (``reduction="batchmean"``).
        """
        log_prob = F.log_softmax(action_logit, dim=-1)
        target_prob = F.one_hot(batch_action, num_classes=action_logit.size(-1)).float()
        return F.kl_div(log_prob, target_prob, reduction="batchmean")

    def _contrastive_loss(self, embeddings: torch.Tensor, actions: torch.Tensor,
                          temperature: float = 0.5) -> torch.Tensor:
        """Supervised contrastive loss (NT-Xent / InfoNCE form) over a batch of embeddings.

        Two samples form a positive pair when they have the same action label. For anchor ``i``
        with positive set ``P(i)`` (same label, ``i`` itself excluded) and cosine similarity ``s``:

            loss_i = -log( sum_{p in P(i)} exp(s_ip / T) / sum_{k != i} exp(s_ik / T) )

        Anchors with no positive in the batch are skipped; the result is the mean over the rest.

        Excluded pairs get a logit of ``-inf`` before ``logsumexp``, so they contribute exactly 0.
        Zeroing the similarity with a 0/1 mask instead would leave ``exp(0) = 1`` in the sums for
        every excluded pair and distort the loss.

        Args:
            embeddings: (batch_size, embedding_size) embedding vectors.
            actions: (batch_size,) action labels.
            temperature: temperature ``T``, default 0.5.

        Returns:
            Scalar loss tensor; zero (still attached to the graph) when no anchor has a positive.
        """
        batch_size = embeddings.size(0)
        logits = F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2) / temperature

        self_mask = torch.eye(batch_size, dtype=torch.bool, device=embeddings.device)
        positive_mask = (actions.unsqueeze(1) == actions.unsqueeze(0)) & ~self_mask

        has_positive = positive_mask.any(dim=1)
        if not bool(has_positive.any()):
            return logits.sum() * 0.0

        # Restrict to anchors that have at least one positive so every logsumexp below is finite.
        logits = logits[has_positive]
        positive_mask = positive_mask[has_positive]
        self_mask = self_mask[has_positive]

        neg_inf = float("-inf")
        log_denominator = torch.logsumexp(logits.masked_fill(self_mask, neg_inf), dim=1)
        log_numerator = torch.logsumexp(logits.masked_fill(~positive_mask, neg_inf), dim=1)
        return (log_denominator - log_numerator).mean()

    # def train(self, task_flags):
    # 为了测试减少步数
    # task_flags.total_steps = task_flags.total_steps // 10
    # return super().train(task_flags)

    # def cleanup(self):
    #     # 任务结束前进行可视化行为嵌入
    #     self._action_represent_visual(self.model_flags)
    #     super().cleanup()

    def _action_represent_visual(self, model_flags):
        """Embed the transitions in the self-supervised buffer and save them for visualization.

        The transitions returned by ``self._sl_buffer.get_list()`` are embedded with the action
        encoder. The (embeddings, actions) pair for the current task is appended to
        ``self.representation_buffers``, and the whole list is written to
        ``representation_buffers.pkl`` with ``torch.save``. The t-SNE plotting code below is
        currently disabled.

        The encoder runs under ``torch.no_grad()``: the embeddings are only converted to NumPy,
        so building an autograd graph over the entire buffer would just waste memory.
        """
        if self._sl_buffer.size() <= 0:
            self.logger.warning("缓冲区为空，无法进行可视化")
            return

        self.logger.info("可视化缓冲区大小：" + str(self._sl_buffer.size()))
        states, actions, next_states = zip(*self._sl_buffer.get_list())
        states = torch.stack(states).to(model_flags.device)
        next_states = torch.stack(next_states).to(model_flags.device)
        actions = torch.stack(actions)

        with torch.no_grad():
            embeddings = self.action_encoder(states, next_states).detach().cpu().numpy()

        # Accumulate this task's embeddings and persist everything gathered so far.
        self.representation_buffers.append((embeddings, actions))
        torch.save(self.representation_buffers, "representation_buffers.pkl")
        print(f"---到任务{self._task_run_id}的可视化缓冲区已保存---")

        # 2-D t-SNE plot (disabled):
        # tsne = TSNE(n_components=2, random_state=0, perplexity=50)
        # embeddings_2d = tsne.fit_transform(embeddings)
        #
        # num_actions = actions.max().item() + 1
        # # Trim the color map to the number of actions
        # dynamic_cmap = ListedColormap(cmap.colors[:num_actions])
        # fig = plt.figure(figsize=(7, 6))
        # sc = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=actions, cmap=dynamic_cmap)
        # cbar = plt.colorbar(sc, ticks=range(num_actions))
        # cbar.ax.set_yticklabels(action_names[:num_actions], fontsize=12)
        # # Hide tick labels but keep the axes
        # ax = plt.gca()
        # ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False, labelbottom=False,
        #                labelleft=False)
        #
        # plt.savefig(f"action_representations_{self._task_run_id}.png")
        # utils.summary_writer.add_figure(f'Action Representation/Explore/{self._task_run_id}', fig)

        # 3-D t-SNE plot from several view angles (disabled):
        # tsne = TSNE(n_components=3, random_state=0)
        # embeddings_3d = tsne.fit_transform(embeddings)
        #
        # num_actions = actions.max().item() + 1
        # cmap = plt.get_cmap('Pastel2', num_actions)
        # colors = cmap(actions.numpy())
        #
        # fig = plt.figure(figsize=(18, 6))
        # angles = [30, 60, 90]
        # for i, angle in enumerate(angles):
        #     ax = fig.add_subplot(1, len(angles), i + 1, projection='3d')
        #     ax.scatter(embeddings_3d[:, 0], embeddings_3d[:, 1], embeddings_3d[:, 2], c=colors)
        #     ax.view_init(30, angle)
        #     ax.set_title(f'View angle: {angle}°')
        #
        # plt.suptitle('Action Representations in 3D')
        # plt.savefig("action_representations_3d.png")
        # plt.show()

    def explore(self, *, max_wait_rounds=32, wait_seconds=5):
        """Fill the self-supervised buffer via the actors, then train the action model on it.

        Steps:

        1. Switch buffer collection on for the actor processes (``shared_params``).
        2. Wait until the buffer holds ``sl_buffer_items`` transitions, for at most
           ``max_wait_rounds`` sleeps of ``wait_seconds`` each. ``free_queue`` is refilled after
           every sleep.
        3. Run ``_sl_learn`` for ``sl_learning_epoch`` epochs and optionally visualize the embeddings.
        4. Reset the rollout queues.
        5. Switch collection and random acting off, locally and in ``shared_params``.
        6. Clear the buffer.

        Args:
            max_wait_rounds: maximum number of sleep/refill rounds (default 32).
            wait_seconds: seconds slept per round (default 5).

        Raises:
            RuntimeError: if the buffer is still not full after the last round.
        """
        flags = self._model_flags

        def buffer_is_full():
            return self._sl_buffer.size() >= flags.sl_buffer_items

        def refill_free_queue():
            # Drain free_queue and re-enqueue every buffer index so actors can write new unrolls.
            while not self.free_queue.empty():
                self.free_queue.get()
            for buffer_index in range(flags.num_buffers):
                self.free_queue.put(buffer_index)

        # Ask the actor processes to collect transitions into the buffer.
        self.shared_params["sl_buffer_toggle"] = True

        self.logger.info("开始探索")
        for _ in range(max_wait_rounds):
            if buffer_is_full():
                break
            time.sleep(wait_seconds)
            self.logger.info("等待缓冲区满" + "，缓冲区大小：" + str(self._sl_buffer.size()))
            refill_free_queue()
        else:
            # The loop only checks before each sleep; look once more after the final round
            # instead of failing on data that arrived during it.
            if not buffer_is_full():
                raise RuntimeError("探索时自监督学习缓冲区未满，请检查Actors是否正常运行")

        self._sl_learn(flags, epoch=flags.sl_learning_epoch)

        if flags.representation_visualize:
            self._action_represent_visual(flags)

        refill_free_queue()
        # Drop filled-buffer indices queued during exploration (presumably so the learner does
        # not train on the random-exploration rollouts).
        while not self.full_queue.empty():
            self.full_queue.get()

        # Exploration runs once per task: switch collection and random acting off everywhere.
        self.sl_buffer_toggle = False
        self.shared_params["sl_buffer_toggle"] = False
        self.actor_is_random = False
        self.shared_params["actor_is_random"] = False

        self._sl_buffer.clear()
        self.logger.info("探索结束")

    def _actors_isalive(self):
        """Return a dict mapping each actor process index to ``is_alive()`` of that process."""
        return {index: process.is_alive() for index, process in enumerate(self._actor_processes)}

    def _save_encoder_decoder(self, file_path):
        """Save the encoder and decoder state dicts.

        The files are written to ``file_path + "_encoder.pth"`` and ``file_path + "_decoder.pth"``.
        """
        checkpoints = ((self.action_encoder, "_encoder.pth"), (self.action_decoder, "_decoder.pth"))
        for module, suffix in checkpoints:
            torch.save(module.state_dict(), file_path + suffix)

    def _load_encoder_decoder(self, file_path):
        """Load encoder/decoder weights written by ``_save_encoder_decoder``.

        The files read are ``file_path + "_encoder.pth"`` and ``file_path + "_decoder.pth"``.
        ``map_location`` remaps the stored tensors to ``self._model_flags.device``, so a checkpoint
        saved on a GPU can still be deserialized on a CPU-only machine.
        """
        device = self._model_flags.device
        self.action_encoder.load_state_dict(torch.load(file_path + "_encoder.pth", map_location=device))
        self.action_decoder.load_state_dict(torch.load(file_path + "_decoder.pth", map_location=device))

    def task_change(self, task_run_id=0):
        """Switch to task ``task_run_id`` and schedule a new exploration phase.

        When the new action space is larger, the decoder is widened via ``setup_decoder`` (and the
        new action space is then explored). When it is smaller, the decoder is kept; only
        ``_task_run_id`` changes, which is what truncates the outputs when actual actions are chosen.

        With ``lambda_ewc > 0`` the decoder's Fisher information is computed from
        ``decoder_dataset`` for the EWC penalty. If no dataset exists yet (no self-supervised
        training has run), this step is skipped with a warning instead of being handed ``None``.

        TODO: for action spaces that both gain and lose actions, exploration is needed to find out
        which actions were added and which removed, so the decoder output size and an action mask
        can be set accordingly.
        """
        current_action_space_size = self.action_spaces[self._task_run_id].n
        new_action_space_size = self.action_spaces[task_run_id].n
        self.logger.info("任务动作空间大小：" + str(current_action_space_size) + " -> " + str(new_action_space_size))
        self._task_run_id = task_run_id

        # Start a new exploration phase. shared_params must be updated too, otherwise the actor
        # processes would not explore.
        self.sl_buffer_toggle = True
        self.shared_params["sl_buffer_toggle"] = True
        self.actor_is_random = True
        self.shared_params["actor_is_random"] = True

        if self._model_flags.lambda_ewc > 0:
            if self.decoder_dataset is None:
                self.logger.warning("解码器数据集为空，跳过Fisher信息计算")
            else:
                self.action_decoder.compute_fisher_information(self.decoder_dataset, device=self._model_flags.device)

        if new_action_space_size > current_action_space_size:
            # Widen the decoder used by the policy networks.
            self.setup_decoder()
